{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/fanyixing/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "# Shared setup notebook. Per the recorded output it loads the raw data packs\n",
    "# (`train_pack_raw`, `dev_pack_raw`, `test_pack_raw`), initialises\n",
    "# `ranking_task` with NDCG@3, NDCG@5 and MAP metrics, and loads `glove_embedding`.\n",
    "# NOTE(review): `mz` (matchzoo) is presumably imported there as well -- confirm.\n",
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3277.07it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:08<00:00, 2128.88it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 443061.44it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 76946.33it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 61087.65it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 357127.07it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 382711.17it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 234249/234249 [00:00<00:00, 1785462.63it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 3638.16it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:08<00:00, 2107.97it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 80677.64it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 144150.06it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 97328.59it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 307271.83it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 373124.96it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 49205.36it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 38907.05it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 3540.82it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 2163.02it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 8943.66it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 72060.99it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 100249.71it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 116323.05it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 329040.24it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 36927.55it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 38826.80it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 3748.05it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2106.07it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89301.64it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 91550.01it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 99896.44it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 183879.03it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 376464.36it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 44442.71it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 40030.45it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit the preprocessor (frequency filter + vocabulary) on the training split,\n",
    "# then reuse the fitted instance to transform the dev and test splits.\n",
    "preprocessor = mz.preprocessors.BasicPreprocessor(\n",
    "    fixed_length_left=10,\n",
    "    fixed_length_right=100,\n",
    "    remove_stop_words=True,\n",
    ")\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "valid_pack_processed, test_pack_processed = (\n",
    "    preprocessor.transform(raw_pack) for raw_pack in (dev_pack_raw, test_pack_raw)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7ff0d35a0f98>,\n",
       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7ff0d4f9b0f0>,\n",
       " 'vocab_size': 16546,\n",
       " 'embedding_input_dim': 16546,\n",
       " 'input_shapes': [(10,), (100,)]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display what the fitted preprocessor exposes (vocabulary size, embedding\n",
    "# input dimension, input shapes, fitted units); the model-building cell\n",
    "# merges this dict into the model params via `model.params.update(...)`.\n",
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 100)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             4963800     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_1 (Conv1D)               (None, 10, 128)      115328      embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv1d_2 (Conv1D)               (None, 100, 128)     115328      embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling1d_1 (MaxPooling1D)  (None, 2, 128)       0           conv1d_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling1d_2 (MaxPooling1D)  (None, 25, 128)      0           conv1d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 256)          0           max_pooling1d_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "flatten_2 (Flatten)             (None, 3200)         0           max_pooling1d_2[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 3456)         0           flatten_1[0][0]                  \n",
      "                                                                 flatten_2[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 3456)         0           concatenate_1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 100)          345700      dropout_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 1)            101         dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 1)            2           dense_2[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 5,540,259\n",
      "Trainable params: 5,540,259\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.ArcI()\n",
    "# Start from the values the preprocessor derived from the training data.\n",
    "model.params.update(preprocessor.context)\n",
    "\n",
    "# Hyper-parameters for this run, assigned one at a time in the order listed.\n",
    "arci_settings = (\n",
    "    ('task', ranking_task),\n",
    "    ('embedding_output_dim', glove_embedding.output_dim),\n",
    "    ('num_blocks', 1),\n",
    "    ('left_filters', [128]),\n",
    "    ('left_kernel_sizes', [3]),\n",
    "    ('left_pool_sizes', [4]),\n",
    "    ('right_filters', [128]),\n",
    "    ('right_kernel_sizes', [3]),\n",
    "    ('right_pool_sizes', [4]),\n",
    "    ('conv_activation_func', 'relu'),\n",
    "    ('mlp_num_layers', 1),\n",
    "    ('mlp_num_units', 100),\n",
    "    ('mlp_num_fan_out', 1),\n",
    "    ('mlp_activation_func', 'relu'),\n",
    "    ('dropout_rate', 0.9),\n",
    "    ('optimizer', 'adadelta'),\n",
    ")\n",
    "for param_name, param_value in arci_settings:\n",
    "    model.params[param_name] = param_value\n",
    "\n",
    "# Fill whatever is still unset, then build, compile and print the Keras summary.\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the embedding matrix from the fitted vocabulary's term index and\n",
    "# load it into the model.\n",
    "term_index = preprocessor.context['vocab_unit'].state['term_index']\n",
    "embedding_matrix = glove_embedding.build_matrix(term_index)\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch monitoring uses the held-out dev split (`valid_pack_processed`).\n",
    "# Scoring the test split after every epoch would let it influence decisions\n",
    "# such as how many epochs to train, so the test split is kept out of the\n",
    "# training loop and reserved for a final evaluation.\n",
    "pred_x, pred_y = valid_pack_processed.unpack()\n",
    "# batch_size equals the number of examples, so the split is scored in one batch.\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num batches: 102\n"
     ]
    }
   ],
   "source": [
    "# Pair-mode batch generator over the processed training pack.\n",
    "# NOTE(review): `num_dup` / `num_neg` presumably control pair duplication and\n",
    "# the number of negatives sampled per positive -- confirm in the DataGenerator docs.\n",
    "generator_options = dict(mode='pair', num_dup=2, num_neg=1, batch_size=20)\n",
    "train_generator = mz.DataGenerator(train_pack_processed, **generator_options)\n",
    "print('num batches:', len(train_generator))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 12s 113ms/step - loss: 0.9915\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6305896279538845 - normalized_discounted_cumulative_gain@5(0.0): 0.6776773015027755 - mean_average_precision(0.0): 0.633078259024834\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 16s 153ms/step - loss: 0.9609\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.6414052424694674 - normalized_discounted_cumulative_gain@5(0.0): 0.6896110630787813 - mean_average_precision(0.0): 0.655464619746667\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 15s 147ms/step - loss: 0.9213\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.609449432076622 - normalized_discounted_cumulative_gain@5(0.0): 0.6535863921295467 - mean_average_precision(0.0): 0.6256401326730503\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 14s 139ms/step - loss: 0.8644\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5668504718763049 - normalized_discounted_cumulative_gain@5(0.0): 0.6184823173536319 - mean_average_precision(0.0): 0.5887412898235421\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 14s 140ms/step - loss: 0.8046\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5665334463767309 - normalized_discounted_cumulative_gain@5(0.0): 0.6176058113896864 - mean_average_precision(0.0): 0.5852152847278305\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 15s 148ms/step - loss: 0.8133\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5653296421157549 - normalized_discounted_cumulative_gain@5(0.0): 0.6098561773567537 - mean_average_precision(0.0): 0.5787384301794942\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 15s 144ms/step - loss: 0.7223\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5115731066519837 - normalized_discounted_cumulative_gain@5(0.0): 0.5696043849325363 - mean_average_precision(0.0): 0.5309471148833632\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 13s 129ms/step - loss: 0.7452\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5654535528839739 - normalized_discounted_cumulative_gain@5(0.0): 0.6150188550188977 - mean_average_precision(0.0): 0.5855908181807582\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 15s 149ms/step - loss: 0.6732\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.537504491728362 - normalized_discounted_cumulative_gain@5(0.0): 0.5888879094450309 - mean_average_precision(0.0): 0.556922967842061\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 15s 146ms/step - loss: 0.6431\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5294134618498313 - normalized_discounted_cumulative_gain@5(0.0): 0.5900024900110544 - mean_average_precision(0.0): 0.5542640985018126\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 14s 134ms/step - loss: 0.5859\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5520633189989336 - normalized_discounted_cumulative_gain@5(0.0): 0.6108468663800319 - mean_average_precision(0.0): 0.5791519476377355\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 15s 146ms/step - loss: 0.5602\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.534624128426663 - normalized_discounted_cumulative_gain@5(0.0): 0.5889541538361915 - mean_average_precision(0.0): 0.551507001163675\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 15s 144ms/step - loss: 0.5450\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5627233027710186 - normalized_discounted_cumulative_gain@5(0.0): 0.6108236823978452 - mean_average_precision(0.0): 0.584069096207155\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 15s 144ms/step - loss: 0.5581\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5513529195409487 - normalized_discounted_cumulative_gain@5(0.0): 0.5999350241763287 - mean_average_precision(0.0): 0.568106916524638\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 16s 156ms/step - loss: 0.4980\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5546799339708178 - normalized_discounted_cumulative_gain@5(0.0): 0.6152275605059057 - mean_average_precision(0.0): 0.5795762526981569\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 14s 141ms/step - loss: 0.5071\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5600349988760164 - normalized_discounted_cumulative_gain@5(0.0): 0.6072060773562872 - mean_average_precision(0.0): 0.5774453145256668\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 14s 140ms/step - loss: 0.4518\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.554905029677223 - normalized_discounted_cumulative_gain@5(0.0): 0.6021378901796073 - mean_average_precision(0.0): 0.5722005742097239\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 15s 145ms/step - loss: 0.4292\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5538251855689398 - normalized_discounted_cumulative_gain@5(0.0): 0.6006882253891397 - mean_average_precision(0.0): 0.5649684864479169\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 12s 116ms/step - loss: 0.4222\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5502126537672055 - normalized_discounted_cumulative_gain@5(0.0): 0.5933742887299561 - mean_average_precision(0.0): 0.5631647115191418\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 14s 142ms/step - loss: 0.3871\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.545929755381746 - normalized_discounted_cumulative_gain@5(0.0): 0.5965961908898312 - mean_average_precision(0.0): 0.5620997683843287\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 15s 145ms/step - loss: 0.3485\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.545785357053379 - normalized_discounted_cumulative_gain@5(0.0): 0.6002958901867365 - mean_average_precision(0.0): 0.5651984875273516\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 15s 152ms/step - loss: 0.3665\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.554029843099671 - normalized_discounted_cumulative_gain@5(0.0): 0.6048247555957358 - mean_average_precision(0.0): 0.5708237928362012\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 13s 128ms/step - loss: 0.3638\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5491694007852175 - normalized_discounted_cumulative_gain@5(0.0): 0.6068284143675057 - mean_average_precision(0.0): 0.568656261259232\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 13s 132ms/step - loss: 0.3220\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5501388396528246 - normalized_discounted_cumulative_gain@5(0.0): 0.6082083078502355 - mean_average_precision(0.0): 0.5698885871168076\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 14s 134ms/step - loss: 0.3111\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5529613793276678 - normalized_discounted_cumulative_gain@5(0.0): 0.6087770320216022 - mean_average_precision(0.0): 0.5719623784893997\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 15s 143ms/step - loss: 0.2875\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5535491988838098 - normalized_discounted_cumulative_gain@5(0.0): 0.607511521411928 - mean_average_precision(0.0): 0.562998351131385\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 14s 133ms/step - loss: 0.2849\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5469268899993501 - normalized_discounted_cumulative_gain@5(0.0): 0.6037117043525182 - mean_average_precision(0.0): 0.5606121264001305\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 14s 140ms/step - loss: 0.2623\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5550211003938021 - normalized_discounted_cumulative_gain@5(0.0): 0.6047800527522788 - mean_average_precision(0.0): 0.5628610719197386\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 14s 141ms/step - loss: 0.2663\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5457115429389982 - normalized_discounted_cumulative_gain@5(0.0): 0.6059713781830973 - mean_average_precision(0.0): 0.5609192818766827\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 15s 145ms/step - loss: 0.2661\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5475428253197103 - normalized_discounted_cumulative_gain@5(0.0): 0.610957175825584 - mean_average_precision(0.0): 0.5696311917780938\n"
     ]
    }
   ],
   "source": [
    "# Train for 30 epochs; the `evaluate` callback prints the ranking metrics\n",
    "# after every epoch.\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    epochs=30,\n",
    "    callbacks=[evaluate],\n",
    "    workers=30,\n",
    "    use_multiprocessing=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
